{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 431
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 8408,
     "status": "ok",
     "timestamp": 1586679815628,
     "user": {
      "displayName": "Xue Jiang",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjfXAo9TmWFYS-GhDq8HOVk3N1FnJiyBPH8h9M=s64",
      "userId": "12544639825510830792"
     },
     "user_tz": -480
    },
    "id": "x-p3iFQxK5ap",
    "outputId": "bc155108-4918-40d4-f1f4-c3e52e68d37b"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: transformers in /usr/local/lib/python3.6/dist-packages (2.8.0)\n",
      "Requirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from transformers) (3.0.12)\n",
      "Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from transformers) (2.21.0)\n",
      "Requirement already satisfied: dataclasses; python_version < \"3.7\" in /usr/local/lib/python3.6/dist-packages (from transformers) (0.7)\n",
      "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from transformers) (1.18.2)\n",
      "Requirement already satisfied: sentencepiece in /usr/local/lib/python3.6/dist-packages (from transformers) (0.1.85)\n",
      "Requirement already satisfied: sacremoses in /usr/local/lib/python3.6/dist-packages (from transformers) (0.0.38)\n",
      "Requirement already satisfied: boto3 in /usr/local/lib/python3.6/dist-packages (from transformers) (1.12.38)\n",
      "Requirement already satisfied: tokenizers==0.5.2 in /usr/local/lib/python3.6/dist-packages (from transformers) (0.5.2)\n",
      "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.6/dist-packages (from transformers) (4.38.0)\n",
      "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.6/dist-packages (from transformers) (2019.12.20)\n",
      "Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (3.0.4)\n",
      "Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2.8)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2020.4.5.1)\n",
      "Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (1.24.3)\n",
      "Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (7.1.1)\n",
      "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (1.12.0)\n",
      "Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (0.14.1)\n",
      "Requirement already satisfied: botocore<1.16.0,>=1.15.38 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers) (1.15.38)\n",
      "Requirement already satisfied: s3transfer<0.4.0,>=0.3.0 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers) (0.3.3)\n",
      "Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers) (0.9.5)\n",
      "Requirement already satisfied: python-dateutil<3.0.0,>=2.1 in /usr/local/lib/python3.6/dist-packages (from botocore<1.16.0,>=1.15.38->boto3->transformers) (2.8.1)\n",
      "Requirement already satisfied: docutils<0.16,>=0.10 in /usr/local/lib/python3.6/dist-packages (from botocore<1.16.0,>=1.15.38->boto3->transformers) (0.15.2)\n"
     ]
    }
   ],
   "source": [
    "!pip install transformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "nfMilbCALMfS"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import GPT2Tokenizer, GPT2LMHeadModel\n",
    "\n",
    "# 加载预训练模型tokenizer (vocabulary)\n",
    "tokenizer = GPT2Tokenizer.from_pretrained('gpt2')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 10336,
     "status": "ok",
     "timestamp": 1586679817578,
     "user": {
      "displayName": "Xue Jiang",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjfXAo9TmWFYS-GhDq8HOVk3N1FnJiyBPH8h9M=s64",
      "userId": "12544639825510830792"
     },
     "user_tz": -480
    },
    "id": "62HfFUZtLXzO",
    "outputId": "c4af2fa1-4382-4cb4-faa8-92c98ab0c52f"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2061, 318, 262, 14162, 1097, 287, 262]\n"
     ]
    }
   ],
   "source": [
    "# 对文本输入进行编码\n",
    "text = \"What is the fastest car in the\"\n",
    "indexed_tokens = tokenizer.encode(text)\n",
    "print(indexed_tokens)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 10330,
     "status": "ok",
     "timestamp": 1586679817579,
     "user": {
      "displayName": "Xue Jiang",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjfXAo9TmWFYS-GhDq8HOVk3N1FnJiyBPH8h9M=s64",
      "userId": "12544639825510830792"
     },
     "user_tz": -480
    },
    "id": "Y-buHhZ_MC_Z",
    "outputId": "d852ce45-0746-4a9d-c99b-b3214ba095a0"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 2061,   318,   262, 14162,  1097,   287,   262]])\n"
     ]
    }
   ],
   "source": [
    "# 在PyTorch张量中转换indexed_tokens\n",
    "tokens_tensor = torch.tensor([indexed_tokens])\n",
    "print(tokens_tensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000,
     "referenced_widgets": [
      "7308ec1c4a4d49668a9679f6f04b966e",
      "645ee7f3700e44e8a0b7d7c272308082",
      "09f58ee47d864d708308146c2c822654",
      "177bf6ac863c4ebdb7890787acc5afe7",
      "749ee62bb7944ee8b91c394cdb0d2a5f",
      "c6814cc0040b40779a182dfca0981c18",
      "414dbe4fca7b404d84b80b59cd648777",
      "daba7aa1819347d7b04b676b7e06a483",
      "e108df794710419994c9e0348cc9b75d",
      "1c8d81161a704b6782bb357cf70ddb30",
      "849b3182a1264a00b5d70b6396fcb0b1",
      "e234695bfb054f1898eb571780ebdbfe",
      "64fa6f05a02f49ae9909f2dea0675f98",
      "8a0ec337c0c14123a1cb569c5409a7c5",
      "0adf720ba2aa4eba940d05f6f4473191",
      "d258da166c57498c9955c4a4a1fa9c4d"
     ]
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 63939,
     "status": "ok",
     "timestamp": 1586679871196,
     "user": {
      "displayName": "Xue Jiang",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjfXAo9TmWFYS-GhDq8HOVk3N1FnJiyBPH8h9M=s64",
      "userId": "12544639825510830792"
     },
     "user_tz": -480
    },
    "id": "LPxTLHXlMhPS",
    "outputId": "3b7155c5-a914-45c6-da18-c851aa57c670"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7308ec1c4a4d49668a9679f6f04b966e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Downloading', max=224, style=ProgressStyle(description_width=…"
      ]
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e108df794710419994c9e0348cc9b75d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Downloading', max=548118077, style=ProgressStyle(description_…"
      ]
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "GPT2LMHeadModel(\n",
       "  (transformer): GPT2Model(\n",
       "    (wte): Embedding(50257, 768)\n",
       "    (wpe): Embedding(1024, 768)\n",
       "    (drop): Dropout(p=0.1, inplace=False)\n",
       "    (h): ModuleList(\n",
       "      (0): Block(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (c_attn): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (c_fc): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (1): Block(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (c_attn): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (c_fc): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (2): Block(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (c_attn): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (c_fc): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (3): Block(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (c_attn): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (c_fc): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (4): Block(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (c_attn): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (c_fc): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (5): Block(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (c_attn): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (c_fc): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (6): Block(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (c_attn): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (c_fc): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (7): Block(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (c_attn): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (c_fc): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (8): Block(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (c_attn): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (c_fc): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (9): Block(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (c_attn): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (c_fc): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (10): Block(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (c_attn): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (c_fc): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (11): Block(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (c_attn): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (c_fc): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "  )\n",
       "  (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 7,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "# 加载预训练模型 (weights)\n",
    "model = GPT2LMHeadModel.from_pretrained('gpt2')\n",
    "\n",
    "#将模型设置为evaluation模式，关闭DropOut模块\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 79755,
     "status": "ok",
     "timestamp": 1586679887019,
     "user": {
      "displayName": "Xue Jiang",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjfXAo9TmWFYS-GhDq8HOVk3N1FnJiyBPH8h9M=s64",
      "userId": "12544639825510830792"
     },
     "user_tz": -480
    },
    "id": "1yNdbQVKLa44",
    "outputId": "e5cdb581-fe90-4e68-f96e-b21fc53b1523"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GPT2LMHeadModel(\n",
       "  (transformer): GPT2Model(\n",
       "    (wte): Embedding(50257, 768)\n",
       "    (wpe): Embedding(1024, 768)\n",
       "    (drop): Dropout(p=0.1, inplace=False)\n",
       "    (h): ModuleList(\n",
       "      (0): Block(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (c_attn): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (c_fc): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (1): Block(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (c_attn): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (c_fc): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (2): Block(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (c_attn): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (c_fc): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (3): Block(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (c_attn): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (c_fc): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (4): Block(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (c_attn): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (c_fc): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (5): Block(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (c_attn): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (c_fc): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (6): Block(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (c_attn): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (c_fc): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (7): Block(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (c_attn): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (c_fc): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (8): Block(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (c_attn): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (c_fc): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (9): Block(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (c_attn): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (c_fc): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (10): Block(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (c_attn): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (c_fc): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (11): Block(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (c_attn): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (c_fc): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "  )\n",
       "  (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 8,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 如果你有GPU，把所有东西都放在cuda上\n",
    "tokens_tensor = tokens_tensor.to('cuda')\n",
    "model.to('cuda')\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "v8HAdyfMMobh"
   },
   "outputs": [],
   "source": [
    "# 预测所有的tokens\n",
    "with torch.no_grad():\n",
    "    outputs = model(tokens_tensor)\n",
    "    predictions = outputs[0]\n",
    "\n",
    "# print(outputs)\n",
    "# print(predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 80679,
     "status": "ok",
     "timestamp": 1586679887958,
     "user": {
      "displayName": "Xue Jiang",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjfXAo9TmWFYS-GhDq8HOVk3N1FnJiyBPH8h9M=s64",
      "userId": "12544639825510830792"
     },
     "user_tz": -480
    },
    "id": "x4nxgC3JMqvB",
    "outputId": "649083b6-0092-42cc-f3ae-9e146897fc7c"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "What is the fastest car in the world\n"
     ]
    }
   ],
   "source": [
    "# 得到预测的单词\n",
    "predicted_index = torch.argmax(predictions[0, -1, :]).item()\n",
    "predicted_text = tokenizer.decode(indexed_tokens + [predicted_index])\n",
    "\n",
    "# 打印预测单词\n",
    "print(predicted_text)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "authorship_tag": "ABX9TyNhDEa4UttKoiGY0ECZR6aC",
   "collapsed_sections": [],
   "name": "使用GPT-2预测下一个单词.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "09f58ee47d864d708308146c2c822654": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "IntProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "IntProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "Downloading: 100%",
      "description_tooltip": null,
      "layout": "IPY_MODEL_c6814cc0040b40779a182dfca0981c18",
      "max": 224,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_749ee62bb7944ee8b91c394cdb0d2a5f",
      "value": 224
     }
    },
    "0adf720ba2aa4eba940d05f6f4473191": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "177bf6ac863c4ebdb7890787acc5afe7": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_daba7aa1819347d7b04b676b7e06a483",
      "placeholder": "​",
      "style": "IPY_MODEL_414dbe4fca7b404d84b80b59cd648777",
      "value": " 224/224 [00:01&lt;00:00, 137B/s]"
     }
    },
    "1c8d81161a704b6782bb357cf70ddb30": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "414dbe4fca7b404d84b80b59cd648777": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "645ee7f3700e44e8a0b7d7c272308082": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "64fa6f05a02f49ae9909f2dea0675f98": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "7308ec1c4a4d49668a9679f6f04b966e": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_09f58ee47d864d708308146c2c822654",
       "IPY_MODEL_177bf6ac863c4ebdb7890787acc5afe7"
      ],
      "layout": "IPY_MODEL_645ee7f3700e44e8a0b7d7c272308082"
     }
    },
    "749ee62bb7944ee8b91c394cdb0d2a5f": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "849b3182a1264a00b5d70b6396fcb0b1": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "IntProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "IntProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "Downloading: 100%",
      "description_tooltip": null,
      "layout": "IPY_MODEL_8a0ec337c0c14123a1cb569c5409a7c5",
      "max": 548118077,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_64fa6f05a02f49ae9909f2dea0675f98",
      "value": 548118077
     }
    },
    "8a0ec337c0c14123a1cb569c5409a7c5": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "c6814cc0040b40779a182dfca0981c18": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "d258da166c57498c9955c4a4a1fa9c4d": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "daba7aa1819347d7b04b676b7e06a483": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "e108df794710419994c9e0348cc9b75d": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_849b3182a1264a00b5d70b6396fcb0b1",
       "IPY_MODEL_e234695bfb054f1898eb571780ebdbfe"
      ],
      "layout": "IPY_MODEL_1c8d81161a704b6782bb357cf70ddb30"
     }
    },
    "e234695bfb054f1898eb571780ebdbfe": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_d258da166c57498c9955c4a4a1fa9c4d",
      "placeholder": "​",
      "style": "IPY_MODEL_0adf720ba2aa4eba940d05f6f4473191",
      "value": " 548M/548M [00:46&lt;00:00, 11.7MB/s]"
     }
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
